from BasicLibraries import *
import functions as fn
from joblib import Parallel, delayed
import warnings

# Silence FutureWarning and RuntimeWarning for the whole interpreter once this
# module has been imported (the filters are global, not scoped to this file).
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
from statsmodels.tools.sm_exceptions import ConvergenceWarning
# Likewise ignore statsmodels' ConvergenceWarning.
warnings.simplefilter('ignore', ConvergenceWarning)

#=======
# NOTE(review): `estimation_type` is not read anywhere in the visible part of
# this file; the regressions below call sm.RLM directly. Confirm whether any
# other module imports it before changing or removing it.
estimation_type = 'RLM'
#=======

class RegressionList():
    def __init__(self):
        """Create the helper object; no instance state is initialised."""
    def shape_data(self, em_type,tc,  golden_data, panel_data, controls_, dummies, initiative_set, scaler_type,
                   mediating_controls = [], strategy_type = None,INTERACTIONTYPE = 'COMPETITOR',
                   number_of_lags = 2, specific_emission_type=False, for_summary_statistics = False):
        """Merge initiative/strategy data, panel data and emissions into one estimation frame.

        Parameters
        ----------
        em_type : column of ``tc`` holding the emission series that is led and summed.
        tc : emissions frame; the columns 'ISIN', 'cyear', 'Estimated' and
            'Reported' are read here.
        golden_data : passed unchanged to ``fn.make_frame`` / ``fn.make_strategy``.
        panel_data : frame merged on the 'mrg' key; must provide every selected
            column ('ids', 'rfyear', 'IMR', 'ISIN', ``INTERACTIONTYPE``, the
            controls and dummies, and 'investment' when the mediator is rescaled).
        controls_ : list of control column names (copied, never mutated).
        dummies : list of dummy column names.
        initiative_set : forwarded to ``fn.make_frame`` when no strategy is given.
        scaler_type : forwarded to ``fn.make_strategy``; the value 'total' also
            triggers the division by ``investment`` below.
        mediating_controls : extra column names to keep. The default list is
            shared between calls but is only read, never mutated.
        strategy_type : ``None`` to use initiatives, otherwise the strategy name
            passed to ``fn.make_strategy``.
        INTERACTIONTYPE : name of the exposure column; it is standardised with
            ``scale`` after rows with missing values are dropped.
        number_of_lags : number of forward-shifted emission columns that are summed.
        specific_emission_type : accepted but not used by this method.
        for_summary_statistics : when true, return only the unscaled frame
            (including 'rfyear') with missing rows dropped.

        Returns
        -------
        Either the summary-statistics frame, or the tuple
        ``(x_, dep1, dep2, controls, emission_dummies)`` where ``x_`` is indexed
        by 'ISIN'. With a strategy, the golden controls are log-transformed and
        appended to the returned control list.
        """
        controls = controls_.copy()
        dep1 = 'total_emissions'
        dep2 = 'initiatives'

        if strategy_type is None:
            res = fn.make_frame(golden_data, initiative_set=initiative_set)
        else:
            print('Looking at strategies, not initiatives')
            # The second return value of make_strategy is not needed here.
            res, _ = fn.make_strategy(golden_data, strategy_type, scaler_type)
            dep2 = 'Strategy'
        res = res.reset_index().sort_values(by=['ids', 'rfyear'])

        # The original strategy/scaler check assigned the same single control in
        # both branches, so it is set unconditionally.
        golden_controls = ['initiatives_tot']

        tcX = tc.copy()
        tcX = tcX.sort_values(by=['ISIN', 'cyear'])
        emission_lags = []
        for lag in range(1, number_of_lags+1):
            lag_col = 'em_type_lag_'+str(lag)
            # shift(-lag) pulls the value found `lag` rows later within each ISIN.
            tcX[lag_col] = tcX[[em_type, 'ISIN']].groupby('ISIN').shift(-lag)
            emission_lags.append(lag_col)
        print('Emission lag:', len(emission_lags))
        # skipna=False: a missing value in any shifted column makes the sum NaN;
        # exact zeros are turned into NaN before the log is taken.
        tcX['total_emissions'] = tcX[emission_lags].sum(axis=1, skipna=False).replace([0], [np.nan])
        emissions_type = ['total_emissions']
        tcX[emissions_type] = tcX[emissions_type].astype(float).apply(np.log)
        tcX['mrg'] = tcX['cyear'].astype(int).astype(str)+'-'+tcX['ISIN']

        res = res.reset_index().sort_values(by=['ids', 'rfyear']).drop(columns=['index'])
        # 'check' = rfyear gap between a row and the next row of the same gvkey;
        # rows whose next observation is more than one year away (or missing)
        # are removed after the mediator has been replaced by its next-row value.
        res['check'] = res[['rfyear', 'gvkey']].groupby('gvkey').diff()
        res['check'] = res[['check', 'gvkey']].groupby('gvkey').shift(-1)
        res[[dep2]+golden_controls] = res[[dep2]+golden_controls+['ids']].groupby('ids').shift(-1)
        res = res[res.check <= 1]

        x = panel_data.merge(res[[dep2] + golden_controls + ['mrg']], on='mrg').merge(tcX[[dep1, 'Estimated', 'Reported', 'mrg']])
        controls = controls + [dep2]

        unscaled_strategies = ['Entropy', 'EntropyOLD', 'Concentration', 'Simpson']
        if (strategy_type is None or scaler_type == 'total') and strategy_type not in unscaled_strategies:
            x[dep2] = x[dep2]/x.investment

        if tc['Reported'].sum() == len(tc):
            print('only focusing on reported emissions')
            emission_dummies = []
        else:
            print('Including emission source')
            emission_dummies = ['Estimated', 'Reported']

        x = x.reset_index().sort_values(by=['ids', 'rfyear']).drop(columns=['index'])

        def _selection(key_columns):
            # Order-preserving de-duplication: a name listed in several of the
            # input lists would otherwise create duplicate column labels.
            names = (emissions_type + [INTERACTIONTYPE] + golden_controls + key_columns
                     + controls + mediating_controls + dummies + emission_dummies)
            return list(dict.fromkeys(names))

        if for_summary_statistics:
            sumstat = x[_selection(['IMR', 'mrg', 'ISIN', 'rfyear'])]
            print('Only returning data for summary statistics purposes')
            return sumstat.dropna()

        x_ = x[_selection(['IMR', 'mrg', 'ISIN'])].set_index('ISIN')
        x_ = x_.replace([-np.inf, np.inf], [np.nan, np.nan]).dropna()
        x_[INTERACTIONTYPE] = x_[[INTERACTIONTYPE]].apply(scale)

        if strategy_type is not None:
            print(len(x_))
            x_[golden_controls] = x_[golden_controls].apply(np.log)
            return x_, dep1, dep2, controls+golden_controls, emission_dummies
        return x_, dep1, dep2, controls, emission_dummies
         
    def shape_data_BASELINE(self, em_type,tc,  golden_data, panel_data, controls_, dummies, initiative_set, scaler_type,
                   mediating_controls = [], strategy_type = None,INTERACTIONTYPE = 'COMPETITOR',
                   number_of_lags = 2, specific_emission_type=False):
        """Build the estimation frame for the baseline run with 'firm_size_t1' as mediator.

        The signature mirrors ``shape_data``. ``controls_``, ``initiative_set``,
        ``scaler_type``, ``mediating_controls``, ``strategy_type`` and
        ``specific_emission_type`` are accepted but not used: the control set is
        fixed to 'at_usd', 'Tangibility', 'Profitability' plus the golden
        controls, and no mediating controls are applied.

        Parameters read here
        --------------------
        em_type : column of ``tc`` that is forward-shifted and summed.
        tc : emissions frame with 'ISIN', 'cyear', 'Estimated', 'Reported'.
        golden_data : frame with 'number_of_initiatives', 'Entropy', 'rfyear', 'ids'.
        panel_data : frame with 'firm_size_t1', 'mrg', 'ids', 'rfyear' and every
            selected column.
        dummies : list of dummy column names.
        INTERACTIONTYPE : exposure column, standardised with ``scale``.
        number_of_lags : number of forward-shifted emission columns to sum.

        Returns
        -------
        ``(x_, dep1, dep2, controls, emission_dummies)`` with ``x_`` indexed by
        'ISIN' and the golden controls log-transformed.
        """
        # The baseline never uses mediating controls, whatever the caller passed.
        mediating_controls = []

        # NOTE(review): the next two assignments add columns to the caller's
        # golden_data object (sort_values below is the first step that returns a
        # new frame). Left as-is because other code may rely on the columns.
        golden_data['initiatives_tot'] = golden_data['number_of_initiatives']
        golden_data['mrg'] = golden_data['rfyear'].astype(int).astype(str)+'-'+golden_data['ids']
        golden_data = golden_data.sort_values(by=['ids', 'rfyear'])
        # Replace each golden control by the value of the next row within `ids`.
        golden_data['initiatives_tot'] = golden_data[['initiatives_tot', 'ids']].groupby('ids').shift(-1)
        golden_data['Entropy'] = golden_data[['Entropy', 'ids']].groupby('ids').shift(-1)

        dep1 = 'total_emissions'
        dep2 = 'Strategy'

        print('Looking at firm size, not a strategy')
        # Explicit copy so the new column is not assigned onto a slice of panel_data.
        res = panel_data[['firm_size_t1', 'mrg', 'ids', 'rfyear']].copy()
        res[dep2] = res['firm_size_t1']
        res = res.reset_index().sort_values(by=['ids', 'rfyear'])

        golden_controls = ['initiatives_tot', 'Entropy']

        tcX = tc.copy()
        tcX = tcX.sort_values(by=['ISIN', 'cyear'])
        emission_lags = []
        for lag in range(1, number_of_lags+1):
            lag_col = 'em_type_lag_'+str(lag)
            # shift(-lag) pulls the value found `lag` rows later within each ISIN.
            tcX[lag_col] = tcX[[em_type, 'ISIN']].groupby('ISIN').shift(-lag)
            emission_lags.append(lag_col)
        print('Emission lag:', len(emission_lags))
        # skipna=False: any missing shifted value makes the sum NaN; exact zeros
        # become NaN before the log.
        tcX['total_emissions'] = tcX[emission_lags].sum(axis=1, skipna=False).replace([0], [np.nan])
        emissions_type = ['total_emissions']
        tcX[emissions_type] = tcX[emissions_type].astype(float).apply(np.log)
        tcX['mrg'] = tcX['cyear'].astype(int).astype(str)+'-'+tcX['ISIN']

        x = panel_data.merge(res[[dep2] + ['mrg']], on='mrg').merge(tcX[[dep1, 'Estimated', 'Reported', 'mrg']])
        x = x.merge(golden_data[['mrg'] + golden_controls])
        controls = ['at_usd', 'Tangibility', 'Profitability'] + golden_controls

        if tc['Reported'].sum() == len(tc):
            print('only focusing on reported emissions')
            emission_dummies = []
        else:
            print('Including emission source')
            emission_dummies = ['Estimated', 'Reported']

        x = x.reset_index().sort_values(by=['ids', 'rfyear']).drop(columns=['index'])

        # golden_controls appears both directly and inside `controls`. Selecting
        # the raw list created duplicate column labels, and the list-keyed log
        # assignment below is ambiguous on such a frame (depending on the pandas
        # version it raises or pairs keys with the wrong columns). De-duplicate
        # while preserving first-occurrence order.
        selected = list(dict.fromkeys(
            emissions_type + [INTERACTIONTYPE, dep2] + golden_controls + ['IMR', 'mrg', 'ISIN']
            + controls + mediating_controls + dummies + emission_dummies
        ))
        x_ = x[selected].set_index('ISIN')
        x_ = x_.replace([-np.inf, np.inf], [np.nan, np.nan]).dropna()
        x_[INTERACTIONTYPE] = x_[[INTERACTIONTYPE]].apply(scale)

        print(len(x_))
        x_[golden_controls] = x_[golden_controls].apply(np.log)

        return x_, dep1, dep2, controls, emission_dummies
         
            
    def filter_outlier(self, X, var_to_filter = ['firm_size', 'investment', 'Tangibility',  'Profitability']):
        X[var_to_filter] = X[var_to_filter][X[var_to_filter] < X[var_to_filter].quantile(0.99)]
        X = X.loc[X[var_to_filter].dropna().index].reset_index(drop=True)
        return X

    def parallel_loop(self, iteration,dt_,mediator_var, Exposure, outcome_var, mid_controls, truncated_controls, dummies, emission_dummies):
        """Run one bootstrap draw of the three robust regressions of the mediation model.

        Parameters
        ----------
        iteration : draw counter supplied by the caller; not used in the body.
        dt_ : DataFrame holding every referenced column plus 'IMR'.
        mediator_var, Exposure, outcome_var : column names.
        mid_controls : extra regressors used only in the outcome model.
        truncated_controls : control column names (copied before
            'firm_size_t1' is removed for the mediator model).
        dummies, emission_dummies : additional regressor column names.

        Returns
        -------
        Tuple ``(alpha, beta, gamma, mediation, gamma + mediation, alphaEFX,
        betaEFX, gammaEFX, proportion_of_mediation, sample_means)``. The first
        three are coefficients rescaled by the ratio of sample standard
        deviations, the ``*EFX`` values are the raw coefficients, and
        ``sample_means`` is the column-wise mean of the resampled frame.
        """
        n_rows = len(dt_)
        # Draw row *positions* with replacement (90% of the sample size). The
        # previous code drew index labels and fed them to .iloc, which is only
        # correct when the index happens to be 0..n-1.
        sample = np.random.choice(n_rows, int(0.9*n_rows))
        dtB = dt_.iloc[sample]

        phi_controls = truncated_controls.copy()
        if 'firm_size_t1' in phi_controls:
            phi_controls.remove('firm_size_t1')

        mediator_ = sm.RLM(dtB[mediator_var], dtB[[Exposure]+phi_controls+['IMR']+dummies], M=sm.robust.norms.HuberT()).fit()
        outcome_ = sm.RLM(dtB[outcome_var], dtB[[Exposure, mediator_var]+truncated_controls+mid_controls+dummies+emission_dummies], M=sm.robust.norms.HuberT()).fit()
        partial_ = sm.RLM(dtB[outcome_var], dtB[[Exposure, mediator_var]+truncated_controls+dummies+emission_dummies], M=sm.robust.norms.HuberT()).fit()

        sd_exposure = dtB[Exposure].std()
        sd_mediator = dtB[mediator_var].std()
        sd_outcome = dtB[outcome_var].std()

        #=== Use unstandardised coefficients to estimate the proportion of the mediated effect
        alphaEFX = mediator_.params[Exposure]
        betaEFX = outcome_.params[mediator_var]
        gammaEFX = partial_.params[Exposure]
        proportion_of_mediation = (alphaEFX*betaEFX)/(alphaEFX*betaEFX+gammaEFX)

        alpha = alphaEFX*sd_exposure/sd_mediator
        beta = betaEFX*sd_mediator/sd_outcome
        gamma = gammaEFX*sd_exposure/sd_outcome
        mediation = alpha*beta

        # numeric_only=True: identifier columns such as 'mrg' are strings, and
        # pandas >= 2.0 raises on them instead of silently skipping them.
        sample_means = dtB.mean(numeric_only=True)

        return alpha,beta,gamma,mediation,gamma+mediation, alphaEFX, betaEFX, gammaEFX, proportion_of_mediation, sample_means
    
    def bootstrapped_parallelised(self, dt_X, outcome_var, mediator_var, Exposure, controls, mid_controls, dummies, emission_dummies,
                                  k_max=1000, n_jobs=6, conf=0.05):
        """Bootstrap the mediation model in parallel and summarise the draws.

        Parameters
        ----------
        dt_X : estimation frame (copied; the caller's object is not modified).
        outcome_var, mediator_var, Exposure : column names.
        controls : control column names; ``mediator_var`` is removed from a copy.
        mid_controls, dummies, emission_dummies : forwarded to ``parallel_loop``.
        k_max : number of bootstrap draws (default 1000, the former constant).
        n_jobs : joblib worker count (default 6, the former constant).
        conf : total tail probability used for the interval bounds (default 0.05).

        Returns
        -------
        ``(boot_sorted, mediation_mean, lower_b, upper_b, exp_on_med,
        med_on_out, k_max, proportion, beta_tilde, total_effect, boot_propMed,
        bootstrappedSAMPLEVARS)``. ``boot_sorted`` and ``boot_propMed`` are
        sorted arrays; ``proportion`` is a formatted string
        ``'mean [lower,upper]'``; ``bootstrappedSAMPLEVARS`` has one row of
        sample means per draw.
        """
        dt_ = dt_X.copy()
        truncated_controls = controls.copy()
        if mediator_var in truncated_controls:
            truncated_controls.remove(mediator_var)
        # reset_index() without drop keeps the former index as a regular column.
        dt_ = dt_.dropna().reset_index()
        dt_ = self.filter_outlier(dt_, var_to_filter=[Exposure, outcome_var]+truncated_controls)
        dt_ = dt_.reset_index(drop=True)

        OUT = Parallel(n_jobs=n_jobs)(
            delayed(self.parallel_loop)(i, dt_, mediator_var, Exposure, outcome_var, mid_controls,
                                        truncated_controls, dummies, emission_dummies)
            for i in range(k_max)
        )

        exp_on_med = [draw[0] for draw in OUT]
        med_on_out = [draw[1] for draw in OUT]
        beta_tilde = [draw[2] for draw in OUT]
        bootsamples = [draw[3] for draw in OUT]
        total_effect = [draw[4] for draw in OUT]
        boot_propMed = [draw[8] for draw in OUT]

        # One concat over all draws instead of growing a frame inside the loop
        # (which re-copied every earlier row on each iteration and started from
        # an empty frame).
        if OUT:
            bootstrappedSAMPLEVARS = pd.concat([pd.DataFrame(draw[9]).transpose() for draw in OUT])
        else:
            bootstrappedSAMPLEVARS = pd.DataFrame()

        boot_sorted = np.sort(bootsamples)
        boot_propMed = np.sort(boot_propMed)

        mediation_mean = np.mean(boot_sorted)
        lower_b, upper_b = self.get_bounds(pd.DataFrame(boot_sorted), k_max, conf)
        prop = np.mean(boot_propMed)
        lower_b_r, upper_b_r = self.get_bounds(pd.DataFrame(boot_propMed), k_max, conf)
        proportion = str(np.round(prop, 3))+' ['+str(np.round(lower_b_r, 3))+','+str(np.round(upper_b_r, 3))+']'

        return boot_sorted, mediation_mean,lower_b, upper_b, exp_on_med, med_on_out, k_max, proportion, beta_tilde, total_effect, boot_propMed, bootstrappedSAMPLEVARS


    

    def parallelised_beta(self, iteration_, dt_, dep, ind,truncated_controls,dummies,emission_dummies):
        """Fit one bootstrap draw of the robust regression of ``dep`` on ``ind`` and controls.

        Parameters
        ----------
        iteration_ : draw counter supplied by the caller; not used in the body.
        dt_ : DataFrame holding every referenced column.
        dep, ind : names of the dependent and the focal independent column.
        truncated_controls, dummies, emission_dummies : further regressor names.

        Returns
        -------
        The coefficient on ``ind`` multiplied by ``std(ind) / std(dep)`` of the
        resampled data.
        """
        n_rows = len(dt_)
        # Draw row *positions* with replacement (90% of the sample size). The
        # previous code drew index labels and used them with .iloc, which is
        # only valid for a 0..n-1 index.
        sample = np.random.choice(n_rows, int(0.9*n_rows))
        dtB = dt_.iloc[sample]
        mod = sm.RLM(dtB[dep], dtB[[ind]+truncated_controls+dummies+emission_dummies], M=sm.robust.norms.HuberT()).fit()
        standardiser = dtB[ind].std()/dtB[dep].std()
        return mod.params.loc[ind]*standardiser
       

    def parallel_single_effect(self, X, dep, ind, controls_, dummies, emission_dummies,
                               k_max=1000, n_jobs=6, conf=0.05):
        """Bootstrap the standardised effect of ``ind`` on ``dep`` in parallel.

        Parameters
        ----------
        X : estimation frame (copied; the caller's object is not modified).
        dep, ind : names of the dependent and the focal independent column.
        controls_ : control column names (copied before being forwarded).
        dummies, emission_dummies : further regressor names.
        k_max : number of bootstrap draws (default 1000, the former constant).
        n_jobs : joblib worker count (default 6, the former constant).
        conf : total tail probability used for the interval bounds (default 0.05).

        Returns
        -------
        ``(efx_sorted, mean, lower_b, upper_b)`` where ``efx_sorted`` is the
        sorted array of bootstrap coefficients.
        """
        dt_ = X.copy()
        dt_ = dt_.dropna().reset_index(drop=True)
        dt_ = self.filter_outlier(dt_, var_to_filter=[dep, ind] + controls_)
        truncated_controls = controls_.copy()

        efx = Parallel(n_jobs=n_jobs)(
            delayed(self.parallelised_beta)(i, dt_, dep, ind, truncated_controls, dummies, emission_dummies)
            for i in range(k_max)
        )

        efx_sorted = np.sort(efx)
        lower_b, upper_b = self.get_bounds(pd.DataFrame(efx_sorted), k_max, conf)

        return efx_sorted, np.mean(efx_sorted), lower_b, upper_b
    
    
    


    def get_bounds(self, X_sorted, k_max, conf):
        lower_b_ = X_sorted.iloc[int(k_max*0.5*conf-1)]
        upper_b_ = X_sorted.iloc[int((1-0.5*conf)*k_max)]
        lower_b = np.min([lower_b_, upper_b_])
        upper_b = np.max([lower_b_, upper_b_])

        return lower_b, upper_b



    



    def effect_on_emissions(self, em_type,tc,  golden_data, panel_data, controls_, dummies, initiative_set, scaler_type='relative',  mediating_controls = [], strategy_type = None,INTERACTIONTYPE = 'COMPETITOR',number_of_lags = 2, make_total_effect=False, zero_remove=False):
        """Shape the data and run the bootstrapped mediation analysis.

        With ``strategy_type == 'Size'`` the frame comes from
        ``shape_data_BASELINE``, otherwise from ``shape_data``. Columns with
        duplicated labels are dropped (first occurrence kept). With
        ``zero_remove`` every column whose sum equals 0 is dropped, except the
        exposure column.

        ``INTERACTIONTYPE`` is now forwarded to the shaping method. Previously
        the frame was always built with the default exposure column while the
        bootstrap used the requested one. ``make_total_effect`` is accepted but
        not used. The default ``mediating_controls`` list is only read.

        Returns
        -------
        ``(x_, med, TotalDist, efx, efxL, efxU, mediation_mean, lower_b,
        upper_b, exp_on_med, med_on_out, k_max, proportion, beta_tilde,
        total_effect, boot_propMed, bootstrappedSAMPLEVARS)``. ``med`` is always
        0 and ``TotalDist``, ``efx``, ``efxL``, ``efxU`` are always -1
        (placeholders kept for the shape of the return value).
        """
        if strategy_type == 'Size':
            print('Running baseline mediation with sales changes')
            shaper = self.shape_data_BASELINE
        else:
            shaper = self.shape_data
        x_, dep1, dep2, controls, emission_dummies = shaper(
            em_type, tc, golden_data, panel_data, controls_, dummies, initiative_set, scaler_type,
            mediating_controls, strategy_type, INTERACTIONTYPE=INTERACTIONTYPE, number_of_lags=number_of_lags)
        x_ = x_.loc[:,~x_.columns.duplicated()].copy()

        if zero_remove:
            col_sums = x_.sum()
            # Never drop the exposure column; 'COMPETITOR' stays protected as
            # before so existing default-argument calls behave identically.
            protected = {INTERACTIONTYPE, 'COMPETITOR'}
            zero_cols = [c for c in col_sums[col_sums == 0].index if c not in protected]
            if zero_cols:
                x_ = x_.drop(columns=zero_cols)

        med = 0
        print('Product of coefficients with bootstrapped CI', len(x_.dropna()))
        (bootsamples_sorted, mediation_mean, lower_b, upper_b, exp_on_med, med_on_out, k_max,
         proportion, beta_tilde, total_effect, boot_propMed, bootstrappedSAMPLEVARS) = self.bootstrapped_parallelised(
            x_, dep1, dep2, INTERACTIONTYPE, controls, mediating_controls, dummies, emission_dummies)

        TotalDist, efx, efxL, efxU = -1, -1,-1,-1
        return x_, med, TotalDist, efx, efxL, efxU, mediation_mean, lower_b, upper_b, exp_on_med, med_on_out, k_max, proportion, beta_tilde, total_effect, boot_propMed, bootstrappedSAMPLEVARS
    
    
    def get_total_effect(self, em_type,tc,  golden_data, panel_data, controls_, dummies, initiative_set, scaler_type='relative',  mediating_controls = [], strategy_type = None,INTERACTIONTYPE = 'COMPETITOR',number_of_lags = 2, zero_remove=False):
        """Bootstrap the total effect of the exposure on emissions.

        The frame is always built with ``shape_data`` (the baseline shaper is
        not used here). The control list returned by ``shape_data`` is replaced
        by the fixed set 'firm_size', 'Tangibility', 'MTB', 'turnover' before
        the bootstrap.

        ``INTERACTIONTYPE`` is now forwarded to ``shape_data``. Previously the
        frame was always built with the default exposure column while the
        bootstrap used the requested one. The default ``mediating_controls``
        list is only read.

        Returns
        -------
        ``(TotalDist, efx, efxL, efxU)`` as returned by ``parallel_single_effect``.
        """
        x_, dep1, dep2, controls, emission_dummies = self.shape_data(
            em_type, tc, golden_data, panel_data, controls_, dummies, initiative_set, scaler_type,
            mediating_controls, strategy_type, INTERACTIONTYPE=INTERACTIONTYPE, number_of_lags=number_of_lags)
        x_ = x_.loc[:,~x_.columns.duplicated()].copy()

        if zero_remove:
            col_sums = x_.sum()
            # Never drop the exposure column; 'COMPETITOR' stays protected as
            # before so existing default-argument calls behave identically.
            protected = {INTERACTIONTYPE, 'COMPETITOR'}
            zero_cols = [c for c in col_sums[col_sums == 0].index if c not in protected]
            if zero_cols:
                x_ = x_.drop(columns=zero_cols)

        controls = ['firm_size', 'Tangibility', 'MTB',  'turnover']
        TotalDist, efx, efxL, efxU = self.parallel_single_effect(x_, dep1, INTERACTIONTYPE, controls, dummies, emission_dummies)
        return TotalDist,  efx, efxL, efxU

    def get_R2(self, X, Y):
        """Return the squared first element of ``pearsonr(X, Y)``.

        Presumably ``pearsonr`` is SciPy's, so this is the squared Pearson
        correlation coefficient; the name arrives via the star import, verify there.
        """
        correlation = pearsonr(X, Y)[0]
        return correlation**2
    def Make_full_regression(self, dt_X, outcome_var, mediator_var, Exposure, controls, mid_controls, dummies, emission_dummies):
        """Fit the three robust regressions of the mediation model once on the full sample.

        Parameters
        ----------
        dt_X : estimation frame (copied; the caller's object is not modified).
        outcome_var, mediator_var, Exposure : column names.
        controls : control column names; ``mediator_var`` is removed from a copy.
        mid_controls : extra regressors used only in the outcome model.
        dummies, emission_dummies : additional regressor column names.

        Returns
        -------
        ``(mediator_, outcome_, partial_, dt_, mediator_var, outcome_var)``:
        the three fitted RLM results, the filtered frame they were fitted on,
        and the two column names passed in.
        """
        dt_ = dt_X.copy()
        truncated_controls = controls.copy()
        if mediator_var in truncated_controls:
            truncated_controls.remove(mediator_var)
        # reset_index() without drop keeps the former index as a regular column.
        dt_ = dt_.dropna().reset_index()
        dt_ = self.filter_outlier(dt_, var_to_filter=[Exposure, outcome_var]+truncated_controls)
        dt_ = dt_.reset_index(drop=True)

        phi_controls = truncated_controls.copy()
        # Guarded like in parallel_loop: list.remove raises ValueError when the
        # control list does not contain 'firm_size_t1'.
        if 'firm_size_t1' in phi_controls:
            phi_controls.remove('firm_size_t1')

        mediator_ = sm.RLM(dt_[mediator_var], dt_[[Exposure]+phi_controls+['IMR']+dummies], M=sm.robust.norms.HuberT()).fit()
        outcome_ = sm.RLM(dt_[outcome_var], dt_[[Exposure, mediator_var]+truncated_controls+mid_controls+dummies+emission_dummies], M=sm.robust.norms.HuberT()).fit()
        # NOTE(review): unlike the bootstrap version in parallel_loop, this
        # partial model omits emission_dummies - confirm which one is intended.
        partial_ = sm.RLM(dt_[outcome_var], dt_[[Exposure, mediator_var]+truncated_controls+dummies], M=sm.robust.norms.HuberT()).fit()

        return mediator_, outcome_, partial_, dt_, mediator_var, outcome_var

    def FULLREGRESSION(self, em_type,tc,  golden_data, panel_data, controls_, dummies, initiative_set, scaler_type='relative',  mediating_controls = [], strategy_type = None,INTERACTIONTYPE = 'COMPETITOR',number_of_lags = 2, zero_remove=False, specific_emission_type=False):
        """Shape the data and fit the mediation regressions once on the full sample.

        With ``strategy_type == 'Size'`` the frame comes from
        ``shape_data_BASELINE``, otherwise from ``shape_data``. Columns with
        duplicated labels are dropped (first occurrence kept). With
        ``zero_remove`` every column whose sum equals 0 is dropped, except the
        exposure column.

        ``INTERACTIONTYPE`` is now forwarded to the shaping method. Previously
        the frame was always built with the default exposure column while the
        regressions used the requested one. ``specific_emission_type`` is
        accepted but not forwarded. The default ``mediating_controls`` list is
        only read.

        Returns
        -------
        ``(mediator_, outcome_, partial_, dt_, mediator_var, outcome_var, x_)``:
        the values returned by ``Make_full_regression`` followed by the shaped
        frame.
        """
        if strategy_type == 'Size':
            print('Running baseline mediation with sales changes')
            shaper = self.shape_data_BASELINE
        else:
            shaper = self.shape_data
        x_, dep1, dep2, controls, emission_dummies = shaper(
            em_type, tc, golden_data, panel_data, controls_, dummies, initiative_set, scaler_type,
            mediating_controls, strategy_type, INTERACTIONTYPE=INTERACTIONTYPE, number_of_lags=number_of_lags)
        x_ = x_.loc[:,~x_.columns.duplicated()].copy()

        if zero_remove:
            col_sums = x_.sum()
            # Never drop the exposure column; 'COMPETITOR' stays protected as
            # before so existing default-argument calls behave identically.
            protected = {INTERACTIONTYPE, 'COMPETITOR'}
            zero_cols = [c for c in col_sums[col_sums == 0].index if c not in protected]
            if zero_cols:
                x_ = x_.drop(columns=zero_cols)

        print(controls)
        print('Product of coefficients with bootstrapped CI')
        mediator_, outcome_, partial_, dt_, mediator_var, outcome_var = self.Make_full_regression(
            x_, dep1, dep2, INTERACTIONTYPE, controls, mediating_controls, dummies, emission_dummies)
        #=============================
        return mediator_, outcome_, partial_, dt_, mediator_var, outcome_var, x_


